from asyncio import Queue
from enum import Enum
import sys, os
from typing import AsyncIterator, Dict, List, Optional, Tuple

import torch

from ktransformers.server.config.log import logger
from ktransformers.server.crud.assistants.assistants import AssistantDatabaseManager
from ktransformers.server.crud.assistants.messages import MessageDatabaseManager
from ktransformers.server.crud.assistants.runs import RunsDatabaseManager
from ktransformers.server.crud.assistants.threads import ThreadsDatabaseManager
from ktransformers.server.exceptions import request_error
from ktransformers.server.schemas.assistants.assistants import AssistantObject
from ktransformers.server.schemas.assistants.messages import MessageCreate, MessageObject, Role
from ktransformers.server.schemas.assistants.runs import RunObject
from ktransformers.server.schemas.assistants.threads import ThreadObject
from ktransformers.server.schemas.base import ObjectID, Order
from ktransformers.server.utils.multi_timer import Profiler


from .args import ConfigArgs,default_args



class BackendInterfaceBase:
    '''
    Interface to inference frameworks, e.g. transformers, exllama.

    Subclasses implement ``__init__`` and ``inference``.
    '''

    args: ConfigArgs
    # NOTE: class-level attribute, so a single Profiler instance is shared by every
    # instance (and subclass) that does not assign its own ``self.profiler``.
    profiler:Profiler = Profiler()

    def __init__(self, args:ConfigArgs = default_args):
        raise NotImplementedError

    
    async def inference(self,local_messages,request_unique_id:Optional[str])->AsyncIterator[str]:
        '''
        Stream model output for ``local_messages``.

        Can be called directly, or by ThreadContext.work.

        local_messages:
            When called by ThreadContext, local_messages are generated by ThreadContext.get_local_messages().
            Implementations must handle whatever form local_messages takes.
        request_unique_id:
            Unique id of different requests, useful when using cache.

        return:
            Async iterator of str chunks for streaming updates.
        '''
        raise NotImplementedError


    def report_last_time_performance(self):
        '''
        Log throughput (tokens/s) and timings of the last tokenize/prefill/decode run.

        Best effort: if reading the profiler raises (presumably when a stage was
        not recorded) or a timer reads zero (ZeroDivisionError), a short notice
        is logged instead of propagating the error.
        '''
        try:
            tokenize_time = self.profiler.get_timer_sec('tokenize')
            prefill_time = self.profiler.get_timer_sec('prefill')
            decode_time = self.profiler.get_timer_sec('decode')
            prefill_count = self.profiler.get_counter('prefill')
            decode_count = self.profiler.get_counter('decode')

            logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
        except Exception:
            # Catch ordinary errors only; a bare ``except`` would also swallow
            # KeyboardInterrupt, SystemExit and asyncio.CancelledError.
            logger.info('Performance statistics not recorded')


class ThreadContext:
    '''
    A thread context holding assistant logic.

    Loads the thread, assistant and message history of a run from the database
    and drives one assistant reply through ``work``. Subclasses implement
    ``get_local_messages``.
    '''

    args: ConfigArgs
    # Assistant Logic
    assistant: Optional[AssistantObject] = None
    related_threads : List[ThreadObject]
    thread: ThreadObject
    # Class-level default only; __init__ always rebinds ``self.messages`` to the
    # list loaded from the database, so instances do not share this list.
    messages: List[MessageObject] = [] 
    run: RunObject

    interface: Optional[BackendInterfaceBase] = None
     
    queue: Optional[Queue] = None
    # NOTE: class-level attribute, shared by all instances unless reassigned.
    timer: Profiler = Profiler()

    def __init__(self, run: RunObject,interface:BackendInterfaceBase, args: ConfigArgs = default_args) -> None:
        self.args = args
        self.thread_manager = ThreadsDatabaseManager()
        self.message_manager = MessageDatabaseManager()
        self.runs_manager = RunsDatabaseManager()
        self.assistant_manager = AssistantDatabaseManager()
        self.thread = self.thread_manager.db_get_thread_by_id(run.thread_id)
        self.assistant = self.assistant_manager.db_get_assistant_by_id(run.assistant_id)
        # Message history of the thread in ascending order.
        self.messages = self.message_manager.db_list_messages_of_thread(run.thread_id,order=Order.ASC)
        logger.debug(f"{len(self.messages)} messages loaded from database")
        self.interface = interface
        self.update_by_run(run,args)

    def get_local_messages(self):
        '''
        Get local messages, used as the input to interface.inference.

        This function is intended for message preprocessing, e.g. applying a chat template.
        '''
        raise NotImplementedError

    def update_by_run(self,run:RunObject,args:ConfigArgs = default_args):
        '''Rebind the current run and config args.'''
        self.run = run 
        self.args = args
       
    def put_user_message(self, message: MessageObject):
        '''
        Append a new user message to the local message list.

        The message must be a user message of this thread whose status is
        ``in_progress``. Otherwise raises AssertionError. The error is raised
        explicitly so the check is not stripped under ``python -O``.
        '''
        if not (
            message.role.is_user() and message.thread_id == self.thread.id and message.status == MessageObject.Status.in_progress
        ):
            raise AssertionError('put_user_message expects an in-progress user message of this thread')
        self.messages.append(message)

    def delete_user_message(self,message_id: ObjectID):
        '''Drop every local message whose id equals ``message_id``.'''
        self.messages = [m for m in self.messages if m.id != message_id]

    async def work(self)->AsyncIterator:
        '''
        Produce the assistant reply to the latest user message, yielding stream events.

        The last message in ``self.messages`` must be a user message. It is
        marked completed and synced to the database. A new assistant reply
        message is created and appended, then tokens from
        ``self.interface.inference`` are streamed into it as deltas.

        If ``self.run.status`` is ``cancelling`` when a token arrives, streaming
        stops and cancelled/incomplete events are yielded. Otherwise, with status
        ``in_progress``, completed events are yielded. Finally the reply message
        and the run are synced to the database.

        Raises:
            the exception returned by ``request_error`` if there are no messages
            or the last message is not from the user;
            NotImplementedError if the run status after streaming is neither
            cancelling nor in_progress.
        '''
        logger.debug('start working')
        # Guard against an empty history as well (previously an IndexError).
        if not self.messages or not self.messages[-1].role.is_user():
            raise request_error('user must talk before LLM can talk')
        user_message = self.messages[-1]
        user_message.status = MessageObject.Status.completed
        user_message.sync_db()

        # Must be computed before reply_message is inserted into self.messages;
        # presumably get_local_messages implementations read self.messages.
        local_messages = self.get_local_messages()

        reply_message = self.message_manager.create_message_object(
                            self.thread.id,
                            self.run.id,
                            MessageCreate(role=Role.assistant, content=""),    
                        )
        reply_message.assistant_id = self.assistant.id
        self.messages.append(reply_message) 

        yield reply_message.stream_response_with_event(MessageObject.Status.created)
        yield reply_message.stream_response_with_event(MessageObject.Status.in_progress)
        yield self.run.stream_response_with_event(RunObject.Status.in_progress)

        # The thread id doubles as the request id passed to the backend.
        async for token in self.interface.inference(local_messages,self.thread.id):     
            # Cancellation is polled once per received token.
            if self.run.status == RunObject.Status.cancelling:
                logger.warning(f'Run {self.run.id} cancelling')
                break
            yield reply_message.append_message_delta(token)
        
        if self.run.status == RunObject.Status.cancelling:
            yield self.run.stream_response_with_event(RunObject.Status.cancelled)
            yield reply_message.stream_response_with_event(MessageObject.Status.incomplete)
        elif self.run.status == RunObject.Status.in_progress:
            yield self.run.stream_response_with_event(RunObject.Status.completed)
            yield reply_message.stream_response_with_event(MessageObject.Status.completed)
        else:
            raise NotImplementedError(f'{self.run.status} should not appear here')

        reply_message.sync_db()
        self.run.sync_db()